{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e3wdiK5fGUoZ"
      },
      "source": [
        "# Advanced RAG with guided generation\n",
        "\n",
        "[txtai](https://github.com/neuml/txtai) is an all-in-one embeddings database for semantic search, LLM orchestration and language model workflows.\n",
        "\n",
        "A standard RAG process typically runs a single vector search query and returns the closest matches. Those matches are then passed into an LLM prompt to limit the context and help ensure the generated answers are factually correct. This works well for most simple cases. More complex use cases require a more advanced approach.\n",
        "\n",
        "This notebook will demonstrate how constrained or guided generation can be applied to better control LLM output."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p8BbfjrhH-V2"
      },
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-OXsTQgaGQPM"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai#egg=txtai autoawq outlines"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Define the RAG process\n",
        "\n",
        "The first step is to define the RAG process. The following code creates an LLM instance and defines a method that takes a question and context, then prompts the LLM."
      ],
      "metadata": {
        "id": "Wj7SlK_ZdwN9"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zC_yDuyKk8ZY"
      },
      "outputs": [],
      "source": [
        "from txtai import LLM\n",
        "\n",
        "llm = LLM(\"TheBloke/Mistral-7B-OpenOrca-AWQ\")\n",
        "\n",
        "def rag(question, text):\n",
        "    prompt = f\"\"\"<|im_start|>system\n",
        "    You are a friendly assistant. You answer questions from users.<|im_end|>\n",
        "    <|im_start|>user\n",
        "    Answer the following question using only the context below. Only include information specifically discussed.\n",
        "\n",
        "    question: {question}\n",
        "    context: {text} <|im_end|>\n",
        "    <|im_start|>assistant\n",
        "    \"\"\"\n",
        "\n",
        "    return llm(prompt, maxlength=4096)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's run a simple RAG call to get an idea of the default behavior."
      ],
      "metadata": {
        "id": "M2NQHT8nfZWJ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Manually generated context. Replace with an Embedding search or other request. See prior examples on txtai's documentation site for more.\n",
        "context = \"\"\"\n",
        "England's terrain chiefly consists of low hills and plains, especially in the centre and south.\n",
        "The Battle of Hastings was fought on 14 October 1066 between the Norman army of William, the Duke of Normandy, and an English army under the Anglo-Saxon King Harold Godwinson\n",
        "Bounded by the Atlantic Ocean on the east, Brazil has a coastline of 7,491 kilometers (4,655 mi).\n",
        "Spain pioneered the exploration of the New World and the first circumnavigation of the globe.\n",
        "Christopher Columbus lands in the Caribbean in 1492.\n",
        "\"\"\"\n",
        "\n",
        "print(rag(\"List the countries discussed\", context))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yVXLDj5wfkqF",
        "outputId": "984f5e15-1949-413a-d4f0-4e55c991ede2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "1. England\n",
            "2. Brazil\n",
            "3. Spain\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Guided Generation\n",
        "\n",
        "The next step is defining how to guide generation. For this step, we'll use the [Outlines](https://github.com/outlines-dev/outlines) library. Outlines is a library for controlling how tokens are generated. It applies logic to enforce schemas, regular expressions and/or specific output formats such as JSON.\n",
        "\n",
        "For our first example, we'll guide generation with a model that has answers and citations. With this multi-answer and multi-citation model, we can generate multiple answers along with associated references on how those answers were derived."
      ],
      "metadata": {
        "id": "co2ZA6hyeD3e"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import List\n",
        "\n",
        "from outlines.integrations.transformers import JSONPrefixAllowedTokens\n",
        "from pydantic import BaseModel\n",
        "\n",
        "class Response(BaseModel):\n",
        "    answers: List[str]\n",
        "    citations: List[str]\n",
        "\n",
        "# Define method that guides LLM generation\n",
        "prefix_allowed_tokens_fn=JSONPrefixAllowedTokens(\n",
        "    schema=Response,\n",
        "    tokenizer_or_pipe=llm.generator.llm.pipeline.tokenizer,\n",
        "    whitespace_pattern=r\" ?\"\n",
        ")\n",
        "\n",
        "def rag(question, text):\n",
        "    prompt = f\"\"\"<|im_start|>system\n",
        "    You are a friendly assistant. You answer questions from users.<|im_end|>\n",
        "    <|im_start|>user\n",
        "    Answer the following question using only the context below. Only include information specifically discussed.\n",
        "\n",
        "    question: {question}\n",
        "    context: {text} <|im_end|>\n",
        "    <|im_start|>assistant\n",
        "    \"\"\"\n",
        "\n",
        "    return llm(prompt, maxlength=4096, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)\n"
      ],
      "metadata": {
        "id": "5jB_9SMhEZut"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "A couple of things to unpack here.\n",
        "\n",
        "First, note the method `prefix_allowed_tokens_fn`. This method applies a [Pydantic](https://github.com/pydantic/pydantic) model to constrain/guide how the LLM generates tokens. Next, see how that constraint is passed to txtai's LLM pipeline through the `prefix_allowed_tokens_fn` argument.\n",
        "\n",
        "Let's try it out."
      ],
      "metadata": {
        "id": "ifCN1dpjgvTi"
      }
    },
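    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build intuition for what a `prefix_allowed_tokens_fn` callback does, here is a minimal toy sketch using only the standard library. It is not how Outlines is implemented. The vocabulary, the scores and the helper names (`toy_score`, `allowed_tokens`, `generate`) are all made up for illustration. Only the callback shape, `(batch_id, input_ids)` in and a list of permitted token ids out, follows the Hugging Face `generate` convention.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "# Hypothetical toy vocabulary; a real tokenizer has tens of thousands of tokens\n",
        "VOCAB = ['[', ']', '1066', '1492', ', ', 'Sure! ', '<eos>']\n",
        "OPEN, CLOSE, Y1066, Y1492, COMMA, CHAT, EOS = range(len(VOCAB))\n",
        "\n",
        "# Made-up base preferences of the toy model: it likes chatty text best\n",
        "BASE = [1.0, 2.0, 5.0, 4.0, 3.0, 9.0, 0.5]\n",
        "\n",
        "def toy_score(ids, token):\n",
        "    # Penalize repeats so that the toy decoding loop terminates\n",
        "    return BASE[token] - 10 * ids.count(token)\n",
        "\n",
        "def allowed_tokens(batch_id, input_ids):\n",
        "    # Tiny state machine that only permits a JSON list of the two years\n",
        "    if not input_ids:\n",
        "        return [OPEN]\n",
        "    last = input_ids[-1]\n",
        "    if last == OPEN:\n",
        "        return [Y1066, Y1492, CLOSE]\n",
        "    if last in (Y1066, Y1492):\n",
        "        return [COMMA, CLOSE]\n",
        "    if last == COMMA:\n",
        "        return [Y1066, Y1492]\n",
        "    return [EOS]\n",
        "\n",
        "def generate(allowed_fn=None, max_tokens=8):\n",
        "    # Greedy decoding: pick the best scoring token among the permitted ones\n",
        "    ids = []\n",
        "    while len(ids) < max_tokens:\n",
        "        candidates = allowed_fn(0, ids) if allowed_fn else range(len(VOCAB))\n",
        "        best = max(candidates, key=lambda token: toy_score(ids, token))\n",
        "        if best == EOS:\n",
        "            break\n",
        "        ids.append(best)\n",
        "    return ''.join(VOCAB[token] for token in ids)\n",
        "\n",
        "print(generate())                            # free-form text, not valid JSON\n",
        "print(json.loads(generate(allowed_tokens)))  # parses as a list of years\n",
        "```\n",
        "\n",
        "Unconstrained, the toy model produces chatty text that fails `json.loads`. With the callback, each step is restricted to tokens that keep the output on track for a valid JSON list."
      ]
    },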
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "\n",
        "json.loads(rag(\"List the countries discussed\", context))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Uf7_Co4VL0uu",
        "outputId": "ed752198-4074-4220-83f7-358ab9237bce"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'answers': ['England', 'Brazil', 'Spain'],\n",
              " 'citations': [\"England's terrain chiefly consists of low hills and plains, especially in the centre and south.\",\n",
              "  'The Battle of Hastings was fought on 14 October 1066 between the Norman army of William, the Duke of Normandy, and an English army under the Anglo-Saxon King Harold Godwinson.',\n",
              "  'Bounded by the Atlantic Ocean on the east, Brazil has a coastline of 7,491 kilometers (4,655 mi).',\n",
              "  'Spain pioneered the exploration of the New World and the first circumnavigation of the globe.',\n",
              "  'Christopher Columbus lands in the Caribbean in 1492.']}"
            ]
          },
          "metadata": {},
          "execution_count": 157
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "This is pretty 🔥\n",
        "\n",
        "See how the answers match the previous output, except they are now returned as a list. There is also a list of citations supporting how the answers were generated! And the output is valid JSON."
      ],
      "metadata": {
        "id": "rGbvrtZmhi4Y"
      }
    },
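    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Keep in mind that guided generation constrains the shape of the output, not its content. Below is a minimal sketch of one possible sanity check, assuming citations are expected to match context lines verbatim. That assumption is strict: the run above returned the Battle of Hastings citation with a trailing period that the context line lacks, so a real check may need some normalization. The names `check_citations`, `sample_context` and `sample_output` are hypothetical, and the second citation is invented to show a failure.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "# Two context lines copied from the sample context above\n",
        "sample_context = '''\n",
        "Spain pioneered the exploration of the New World and the first circumnavigation of the globe.\n",
        "Christopher Columbus lands in the Caribbean in 1492.\n",
        "'''\n",
        "\n",
        "# Hypothetical model output: the second citation is made up for illustration\n",
        "sample_output = json.dumps({\n",
        "    'answers': ['Spain'],\n",
        "    'citations': [\n",
        "        'Spain pioneered the exploration of the New World and the first circumnavigation of the globe.',\n",
        "        'Columbus sailed from Spain in 1492.',\n",
        "    ],\n",
        "})\n",
        "\n",
        "def check_citations(output, context):\n",
        "    # Split citations into those found verbatim in the context and the rest\n",
        "    lines = {line.strip() for line in context.splitlines() if line.strip()}\n",
        "    citations = json.loads(output)['citations']\n",
        "    found = [c for c in citations if c.strip() in lines]\n",
        "    missing = [c for c in citations if c.strip() not in lines]\n",
        "    return found, missing\n",
        "\n",
        "found, missing = check_citations(sample_output, sample_context)\n",
        "print(len(found), 'grounded;', 'not in context:', missing)\n",
        "```\n",
        "\n",
        "Citations that land in `missing` are candidates for review or for a retry of the RAG call."
      ]
    },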
    {
      "cell_type": "markdown",
      "source": [
        "# Extracting information models\n",
        "\n",
        "In our last example, we'll define a more complex model to help with extracting structured information."
      ],
      "metadata": {
        "id": "W0rsUfwniOgn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Response(BaseModel):\n",
        "    countries: List[str]\n",
        "    geography: List[str]\n",
        "    years: List[str]\n",
        "    people: List[str]\n",
        "\n",
        "prefix_allowed_tokens_fn=JSONPrefixAllowedTokens(\n",
        "    schema=Response,\n",
        "    tokenizer_or_pipe=llm.generator.llm.pipeline.tokenizer,\n",
        "    whitespace_pattern=r\" ?\"\n",
        ")\n",
        "\n",
        "def rag(question, text):\n",
        "    prompt = f\"\"\"<|im_start|>system\n",
        "    You are a friendly assistant. You answer questions from users.<|im_end|>\n",
        "    <|im_start|>user\n",
        "    Answer the following question using only the context below. Only include information specifically discussed.\n",
        "\n",
        "    question: {question}\n",
        "    context: {text} <|im_end|>\n",
        "    <|im_start|>assistant\n",
        "    \"\"\"\n",
        "\n",
        "    return llm(prompt, maxlength=4096, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)\n"
      ],
      "metadata": {
        "id": "XHieIoxRgPSU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "json.loads(rag(\"List the entities discussed\", context))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7_RpO3v4ib_-",
        "outputId": "c3d394db-3f2b-4f60-d2b6-754f8247f929"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'countries': ['England', 'Brazil', 'Spain'],\n",
              " 'geography': ['low hills and plains', 'Atlantic Ocean', 'New World'],\n",
              " 'years': ['1066', '1492'],\n",
              " 'people': ['William, the Duke of Normandy',\n",
              "  'Anglo-Saxon King Harold Godwinson',\n",
              "  'Christopher Columbus']}"
            ]
          },
          "metadata": {},
          "execution_count": 159
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "This is tightly guided generation: the output of the LLM is constrained to a very specific model of information. It's quite impressive and simple to get started with!"
      ],
      "metadata": {
        "id": "fDCFeFjSlveR"
      }
    },
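    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since the output follows a fixed schema, downstream code can load it into typed objects. Here is a minimal sketch using standard library dataclasses. `Entities` and `parse_entities` are hypothetical names, and converting the years to integers is an assumption about what a consumer might want. The sample data is the output shown above.\n",
        "\n",
        "```python\n",
        "import json\n",
        "from dataclasses import dataclass\n",
        "from typing import List\n",
        "\n",
        "@dataclass\n",
        "class Entities:\n",
        "    countries: List[str]\n",
        "    geography: List[str]\n",
        "    years: List[int]\n",
        "    people: List[str]\n",
        "\n",
        "def parse_entities(output):\n",
        "    # Parse the JSON text and convert year strings to integers\n",
        "    data = json.loads(output)\n",
        "    return Entities(\n",
        "        countries=data['countries'],\n",
        "        geography=data['geography'],\n",
        "        years=[int(year) for year in data['years']],\n",
        "        people=data['people'],\n",
        "    )\n",
        "\n",
        "# Output copied from the cell above and serialized back to JSON text\n",
        "sample_output = json.dumps({\n",
        "    'countries': ['England', 'Brazil', 'Spain'],\n",
        "    'geography': ['low hills and plains', 'Atlantic Ocean', 'New World'],\n",
        "    'years': ['1066', '1492'],\n",
        "    'people': ['William, the Duke of Normandy',\n",
        "               'Anglo-Saxon King Harold Godwinson',\n",
        "               'Christopher Columbus'],\n",
        "})\n",
        "\n",
        "entities = parse_entities(sample_output)\n",
        "print(entities.years)\n",
        "```\n",
        "\n",
        "A missing field raises a `KeyError` and a non-numeric year raises a `ValueError`, so malformed content fails loudly instead of passing through."
      ]
    },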
    {
      "cell_type": "markdown",
      "source": [
        "# Wrapping up\n",
        "\n",
        "Guided generation adds a number of different options to the RAG toolkit. It's a flexible approach that can enable more advanced functionality and/or fine-tuned control over how content is created.\n",
        "\n",
        "It does add overhead, which may or may not be acceptable depending on the use case. Expect new methods with improved efficiency and accuracy in the future. The space continues to advance fast!"
      ],
      "metadata": {
        "id": "2LQTekhsl75H"
      }
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.18"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
